{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "| [08_machine_translation/01_从头实现Seq2Seq模型.ipynb](https://github.com/shibing624/nlp-tutorial/tree/main/08_machine_translation/01_从头实现Seq2Seq模型.ipynb)  | 从头实现Seq2Seq翻译模型  |[Open In Colab](https://colab.research.google.com/github/shibing624/nlp-tutorial/blob/main/08_machine_translation/01_从头实现Seq2Seq模型.ipynb) |\n",
    "\n",
    "\n",
    "# 用基于注意力机制的seq2seq神经网络进行翻译\n",
    "\n",
    "\n",
    "**作者**: [Sean Robertson](https://github.com/spro/practical-pytorch)\n",
    "\n",
    "这个教程主要讲解用一个神经网络将法语翻译成英语.\n",
    "\n",
    "```\n",
    "\n",
    "    [KEY: > input, = target, < output]\n",
    "\n",
    "    > il est en train de peindre un tableau .\n",
    "    = he is painting a picture .\n",
    "    < he is painting a picture .\n",
    "\n",
    "    > pourquoi ne pas essayer ce vin delicieux ?\n",
    "    = why not try that delicious wine ?\n",
    "    < why not try that delicious wine ?\n",
    "\n",
    "    > elle n est pas poete mais romanciere .\n",
    "    = she is not a poet but a novelist .\n",
    "    < she not not a poet but a novelist .\n",
    "\n",
    "    > vous etes trop maigre .\n",
    "    = you re too skinny .\n",
    "    < you re all alone .\n",
    "```\n",
    "\n",
    "\n",
    "这是通过`seq2seq网络`  [http://arxiv.org/abs/1409.3215](http://arxiv.org/abs/1409.3215) 实现的简单却强大的想法,\n",
    "通过两个递归神经网络一起工作实现将一个序列转换为另一个.一个编码器网络将输入序列压\n",
    "缩成向量,解码器网络将该矢量展开为新的序列.\n",
    "![seq2seq](../docs/seq2seq.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "为了改进这个模型,我们将使用一种`注意力机制` [https://arxiv.org/abs/1409.0473](https://arxiv.org/abs/1409.0473),\n",
    "它可以让解码器学习将注意力集中在输入序列的特定范围上.\n",
    "\n",
    "**推荐阅读:**\n",
    "\n",
    "我们假设你至少已经安装了PyTorch,了解Python,并且了解张量:\n",
    "\n",
    "-  http://pytorch.org/ PyTorch安装说明\n",
    "-  :doc:`/beginner/deep_learning_60min_blitz` 开始使用PyTorch\n",
    "-  :doc:`/beginner/pytorch_with_examples` 进行广泛而深入的了解\n",
    "-  :doc:`/beginner/former_torchies_tutorial` 如果你是前Lua Torch用户\n",
    "\n",
    "\n",
    "这些内容也有利于了解seq2seq网络和其工作机制:\n",
    "\n",
    "-  `用RNN编码器 - 解码器来学习用于统计机器翻译的短语表示 <http://arxiv.org/abs/1406.1078>`\n",
    "-  `用神经网络进行seq2seq学习 <http://arxiv.org/abs/1409.3215>`\n",
    "-  `神经网络机器翻译联合学习对齐和翻译 <https://arxiv.org/abs/1409.0473>`\n",
    "-  `神经会话模型 <http://arxiv.org/abs/1506.05869>`\n",
    "\n",
    "你还可以找到以前的教程关于Character-Level RNN名称分类 [04_text_classification/04_应用_姓名识别国籍.ipynb](https://github.com/shibing624/nlp-tutorial/tree/main/04_text_classification/04_应用_姓名识别国籍.ipynb)\n",
    "和生成名称 [06_text_generation/01_字符级人名生成.ipynb](https://github.com/shibing624/nlp-tutorial/tree/main/06_text_generation/01_字符级人名生成.ipynb)\n",
    "这些概念与编码器和解码器模型非常相似.\n",
    "\n",
    "更多内容请阅读介绍这些主题的论文:\n",
    "\n",
    "-  `用RNN编码器 - 解码器来学习用于统计机器翻译的短语表示 <http://arxiv.org/abs/1406.1078>`\n",
    "-  `用神经网络进行seq2seq学习 <http://arxiv.org/abs/1409.3215>`\n",
    "-  `神经网络机器翻译联合学习对齐和翻译 <https://arxiv.org/abs/1409.0473>`\n",
    "-  `神经会话模型 <http://arxiv.org/abs/1506.05869>`\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from io import open\n",
    "import unicodedata\n",
    "import string\n",
    "import re\n",
    "import random\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.autograd import Variable\n",
    "from torch import optim\n",
    "import torch.nn.functional as F\n",
    "\n",
    "use_cuda = torch.cuda.is_available()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 加载数据文件\n",
    "\n",
    "\n",
    "这个项目的数据是一组数以千计的英语到法语的翻译对。\n",
    "\n",
    "这个问题在 [Open Data Stack Exchange](http://opendata.stackexchange.com/questions/3888/dataset-of-sentences-translated-into-many-languages)上\n",
    "指导我们使用开放的翻译网站 http://tatoeba.org/ 可下载地址为 http://tatoeba.org/eng/downloads - 更好的是,\n",
    "有人做了额外的工作,切分语言对到单个文本文件中: http://www.manythings.org/anki/\n",
    "\n",
    "英文到法文对太大而不能包含在repo中,因此开始前请下载\n",
    " `data/eng-fra.txt`. 该文件是一个制表符分隔的翻译对列表: \n",
    "\n",
    "```\n",
    "    I am cold.    Je suis froid.\n",
    "```\n",
    "\n",
    "PS:\n",
    "\n",
    "下载数据文件[https://download.pytorch.org/tutorial/data.zip](https://download.pytorch.org/tutorial/data.zip) ，并解压到正确的路径。\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "与character-level RNN教程中使用的字符编码类似,我们将用语言中的每个单词\n",
    "作为独热向量,或者除了单个单词之外(在单词的索引处)的大的零向量. 相较于可能\n",
    "存在于一种语言中仅有十个字符相比,多数都是有大量的字,因此编码向量很大. \n",
    "然而,我们会欺骗性的做一些数据修剪,保证每种语言只使用几千字.\n",
    "\n",
    "![word-encoding.png](../docs/word-encoding@2x.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们需要每个单词对应唯一的索引作为稍后的网络输入和目标.为了追踪这些索引我们使用一个帮助类\n",
    " ``Lang`` 类中有 词 → 索引 (``word2index``) 和 索引 → 词\n",
    "(``index2word``) 的字典, 以及每个词``word2count`` 用来替换稀疏词汇."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "SOS_token = 0\n",
    "EOS_token = 1\n",
    "\n",
    "\n",
    "class Lang:\n",
    "    def __init__(self, name):\n",
    "        self.name = name\n",
    "        self.word2index = {}\n",
    "        self.word2count = {}\n",
    "        self.index2word = {0: \"SOS\", 1: \"EOS\"}\n",
    "        self.n_words = 2  # Count SOS and EOS\n",
    "\n",
    "    def addSentence(self, sentence):\n",
    "        for word in sentence.split(' '):\n",
    "            self.addWord(word)\n",
    "\n",
    "    def addWord(self, word):\n",
    "        if word not in self.word2index:\n",
    "            self.word2index[word] = self.n_words\n",
    "            self.word2count[word] = 1\n",
    "            self.index2word[self.n_words] = word\n",
    "            self.n_words += 1\n",
    "        else:\n",
    "            self.word2count[word] += 1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "这些文件全部采用Unicode编码,为了简化我们将Unicode字符转换为ASCII,\n",
    "使所有内容小写,并修剪大部分标点符号.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 感谢将Unicode字符串转换为纯ASCII\n",
    "# http://stackoverflow.com/a/518232/2809427\n",
    "def unicodeToAscii(s):\n",
    "    return ''.join(\n",
    "        c for c in unicodedata.normalize('NFD', s)\n",
    "        if unicodedata.category(c) != 'Mn'\n",
    "    )\n",
    "\n",
    "# 小写,修剪和删除非字母字符\n",
    "\n",
    "\n",
    "def normalizeString(s):\n",
    "    s = unicodeToAscii(s.lower().strip())\n",
    "    s = re.sub(r\"([.!?])\", r\" \\1\", s)\n",
    "    s = re.sub(r\"[^a-zA-Z.!?]+\", r\" \", s)\n",
    "    return s"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "要读取数据文件,我们将把文件分成行,然后将行成对分开.\n",
    "这些文件都是英文→其他语言,所以如果我们想从其他语言翻译→英文,我们添加了\n",
    "翻转标志 ``reverse``来翻转词语对.\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def readLangs(lang1, lang2, reverse=False):\n",
    "    print(\"Reading lines...\")\n",
    "\n",
    "    # 读取文件并按行分开\n",
    "    lines = open('data/%s-%s.txt' % (lang1, lang2), encoding='utf-8').\\\n",
    "        read().strip().split('\\n')\n",
    "\n",
    "    # 将每一行分成两列并进行标准化\n",
    "    pairs = [[normalizeString(s) for s in l.split('\\t')] for l in lines]\n",
    "\n",
    "    # 翻转对,Lang实例化\n",
    "    if reverse:\n",
    "        pairs = [list(reversed(p)) for p in pairs]\n",
    "        input_lang = Lang(lang2)\n",
    "        output_lang = Lang(lang1)\n",
    "    else:\n",
    "        input_lang = Lang(lang1)\n",
    "        output_lang = Lang(lang2)\n",
    "\n",
    "    return input_lang, output_lang, pairs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "由于有很多例句,我们希望快速训练,我们会将数据集裁剪为相对简短的句子. \n",
    "这里的单词的最大长度是10词(包括结束标点符号),我们正在过滤到翻译\n",
    "成\"I am\"或\"He is\"等形式的句子.(考虑到先前替换了撇号).\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "MAX_LENGTH = 10\n",
    "\n",
    "eng_prefixes = (\n",
    "    \"i am \", \"i m \",\n",
    "    \"he is\", \"he s \",\n",
    "    \"she is\", \"she s\",\n",
    "    \"you are\", \"you re \",\n",
    "    \"we are\", \"we re \",\n",
    "    \"they are\", \"they re \"\n",
    ")\n",
    "\n",
    "\n",
    "def filterPair(p):\n",
    "    return len(p[0].split(' ')) < MAX_LENGTH and \\\n",
    "        len(p[1].split(' ')) < MAX_LENGTH and \\\n",
    "        p[1].startswith(eng_prefixes)\n",
    "\n",
    "\n",
    "def filterPairs(pairs):\n",
    "    return [pair for pair in pairs if filterPair(pair)]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "完整的准备数据的过程:\n",
    "\n",
    "-  加载文本文件切分成行,并切分成单词对:\n",
    "-  文本归一化, 按照长度和内容过滤\n",
    "-  从成对的句子中制作单词列表\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Reading lines...\n",
      "Read 119656 sentence pairs\n",
      "Trimmed to 10835 sentence pairs\n",
      "Counting words...\n",
      "Counted words:\n",
      "fra 4472\n",
      "eng 2908\n",
      "['tu abuses de ton autorite .', 'you are abusing your authority .']\n"
     ]
    }
   ],
   "source": [
    "def prepareData(lang1, lang2, reverse=False):\n",
    "    input_lang, output_lang, pairs = readLangs(lang1, lang2, reverse)\n",
    "    print(\"Read %s sentence pairs\" % len(pairs))\n",
    "    pairs = filterPairs(pairs)\n",
    "    print(\"Trimmed to %s sentence pairs\" % len(pairs))\n",
    "    print(\"Counting words...\")\n",
    "    for pair in pairs:\n",
    "        input_lang.addSentence(pair[0])\n",
    "        output_lang.addSentence(pair[1])\n",
    "    print(\"Counted words:\")\n",
    "    print(input_lang.name, input_lang.n_words)\n",
    "    print(output_lang.name, output_lang.n_words)\n",
    "    return input_lang, output_lang, pairs\n",
    "\n",
    "\n",
    "input_lang, output_lang, pairs = prepareData('eng', 'fra', True)\n",
    "print(random.choice(pairs))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Seq2Seq模型\n",
    "\n",
    "\n",
    "递归神经网络(RNN),是一个按照一个序列进行操作的网路,并\n",
    "将其自己的输出用作后续步骤的输入.\n",
    "\n",
    "一个 `序列到序列网络` [http://arxiv.org/abs/1409.3215](http://arxiv.org/abs/1409.3215), 或\n",
    "seq2seq 网络, 或 `编码解码器网络` [https://arxiv.org/pdf/1406.1078v3.pdf](https://arxiv.org/pdf/1406.1078v3.pdf), \n",
    "是由两个称为编码器和解码器的RNN组成的模型. 编码器读取输入序列并输出单个向量,\n",
    "解码器读取该向量以产生输出序列.\n",
    "\n",
    "\n",
    "与单个RNN的序列预测不同,每个输入对应一个输出,\n",
    "seq2seq模型将我们从序列长度和顺序中解放出来,\n",
    "这使得它成为两种语言之间翻译的理想选择.\n",
    "\n",
    "考虑这句话 \"Je ne suis pas le chat noir\" → \"I am not the\n",
    "black cat\".  输入句子中的大部分单词在输出句子中有直接翻译,\n",
    "但顺序略有不同,例如: \"chat noir\" 和 \"black cat\". 由于 \n",
    "\"ne/pas\"结构, 其中另一个单词在输入的句子中. \n",
    "直接从输入词的序列中直接生成正确的翻译是很困难的.\n",
    "\n",
    "使用seq2seq模型,编码器会创建一个单独的向量,\n",
    "在理想情况下,它将输入序列的\"含义\"编码为单个向量 - 句子的N维空间中的一个点."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 编码器\n",
    "\n",
    "\n",
    "seq2seq网络的编码器是一个RNN,它为输入句子中的每个单词输出一些值.\n",
    "对于每个输入字,编码器输出一个向量和一个隐藏状态,并将隐藏状态用于下一个输入字.\n",
    "\n",
    "![encoder](../docs/encoder-network.png)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "class EncoderRNN(nn.Module):\n",
    "    def __init__(self, input_size, hidden_size):\n",
    "        super(EncoderRNN, self).__init__()\n",
    "        self.hidden_size = hidden_size\n",
    "\n",
    "        self.embedding = nn.Embedding(input_size, hidden_size)\n",
    "        self.gru = nn.GRU(hidden_size, hidden_size)\n",
    "\n",
    "    def forward(self, input, hidden):\n",
    "        embedded = self.embedding(input).view(1, 1, -1)\n",
    "        output = embedded\n",
    "        output, hidden = self.gru(output, hidden)\n",
    "        return output, hidden\n",
    "\n",
    "    def initHidden(self):\n",
    "        result = Variable(torch.zeros(1, 1, self.hidden_size))\n",
    "        if use_cuda:\n",
    "            return result.cuda()\n",
    "        else:\n",
    "            return result"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 解码器\n",
    "\n",
    "\n",
    "解码器是另一个RNN,它接收编码器输出向量并输出一个单词序列来创建翻译.\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 简单的解码器\n",
    "\n",
    "\n",
    "在最简单的seq2seq解码器中,我们只使用编码器的最后一个输出.\n",
    "这个最后的输出有时称为上下文向量,因为它从整个序列编码上下文.\n",
    "该上下文向量被用作解码器的初始隐藏状态.\n",
    "\n",
    "在解码的每一步,解码器都被赋予一个输入指令和隐藏状态.\n",
    "初始输入指令字符串开始的``<SOS>``指令,第一个隐藏状态是上下文向量(编码器的最后隐藏状态).\n",
    "\n",
    "![docoder](../docs/decoder-network.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DecoderRNN(nn.Module):\n",
    "    def __init__(self, hidden_size, output_size):\n",
    "        super(DecoderRNN, self).__init__()\n",
    "        self.hidden_size = hidden_size\n",
    "\n",
    "        self.embedding = nn.Embedding(output_size, hidden_size)\n",
    "        self.gru = nn.GRU(hidden_size, hidden_size)\n",
    "        self.out = nn.Linear(hidden_size, output_size)\n",
    "        self.softmax = nn.LogSoftmax(dim=1)\n",
    "\n",
    "    def forward(self, input, hidden):\n",
    "        output = self.embedding(input).view(1, 1, -1)\n",
    "        output = F.relu(output)\n",
    "        output, hidden = self.gru(output, hidden)\n",
    "        output = self.softmax(self.out(output[0]))\n",
    "        return output, hidden\n",
    "\n",
    "    def initHidden(self):\n",
    "        result = Variable(torch.zeros(1, 1, self.hidden_size))\n",
    "        if use_cuda:\n",
    "            return result.cuda()\n",
    "        else:\n",
    "            return result"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们鼓励你训练和观察这个模型的结果,但为了节省空间,我们将直接进正题引入注意力机制.\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 注意力解码器\n",
    "\n",
    "\n",
    "如果仅在编码器和解码器之间传递上下文向量,则该单个向量承担编码整个句子的负担.\n",
    "\n",
    "注意力允许解码器网络针对解码器自身输出的每一步\"聚焦\"编码器输出的不同部分. \n",
    "首先我们计算一组注意力权重. 这些将被乘以编码器输出矢量获得加权的组合. \n",
    "结果(在代码中为``attn_applied``) 应该包含关于输入序列的特定部分的信息,\n",
    "从而帮助解码器选择正确的输出单词.\n",
    "\n",
    "![](https://i.imgur.com/1152PYf.png)\n",
    "\n",
    "\n",
    "使用解码器的输入和隐藏状态作为输入,利用另一个前馈层 ``attn``计算注意力权重, \n",
    "由于训练数据中有各种大小的句子,为了实际创建和训练此层,\n",
    "我们必须选择最大长度的句子(输入长度,用于编码器输出),以适用于此层.\n",
    "最大长度的句子将使用所有注意力权重,而较短的句子只使用前几个.\n",
    "\n",
    "![attention-decoder-network.png](../docs/attention-decoder-network.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "class AttnDecoderRNN(nn.Module):\n",
    "    def __init__(self, hidden_size, output_size, dropout_p=0.1, max_length=MAX_LENGTH):\n",
    "        super(AttnDecoderRNN, self).__init__()\n",
    "        self.hidden_size = hidden_size\n",
    "        self.output_size = output_size\n",
    "        self.dropout_p = dropout_p\n",
    "        self.max_length = max_length\n",
    "\n",
    "        self.embedding = nn.Embedding(self.output_size, self.hidden_size)\n",
    "        self.attn = nn.Linear(self.hidden_size * 2, self.max_length)\n",
    "        self.attn_combine = nn.Linear(self.hidden_size * 2, self.hidden_size)\n",
    "        self.dropout = nn.Dropout(self.dropout_p)\n",
    "        self.gru = nn.GRU(self.hidden_size, self.hidden_size)\n",
    "        self.out = nn.Linear(self.hidden_size, self.output_size)\n",
    "\n",
    "    def forward(self, input, hidden, encoder_outputs):\n",
    "        embedded = self.embedding(input).view(1, 1, -1)\n",
    "        embedded = self.dropout(embedded)\n",
    "\n",
    "        attn_weights = F.softmax(\n",
    "            self.attn(torch.cat((embedded[0], hidden[0]), 1)), dim=1)\n",
    "        attn_applied = torch.bmm(attn_weights.unsqueeze(0),\n",
    "                                 encoder_outputs.unsqueeze(0))\n",
    "\n",
    "        output = torch.cat((embedded[0], attn_applied[0]), 1)\n",
    "        output = self.attn_combine(output).unsqueeze(0)\n",
    "\n",
    "        output = F.relu(output)\n",
    "        output, hidden = self.gru(output, hidden)\n",
    "\n",
    "        output = F.log_softmax(self.out(output[0]), dim=1)\n",
    "        return output, hidden, attn_weights\n",
    "\n",
    "    def initHidden(self):\n",
    "        result = Variable(torch.zeros(1, 1, self.hidden_size))\n",
    "        if use_cuda:\n",
    "            return result.cuda()\n",
    "        else:\n",
    "            return result"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "还有其他形式的注意力通过使用相对位置方法来解决长度限制. 阅读关于 \"local attention\" 在 `基于注意力的神经机器翻译的有效途径` [https://arxiv.org/abs/1508.04025](https://arxiv.org/abs/1508.04025)\n",
    "\n",
    "## 训练\n",
    "\n",
    "\n",
    "### 准备训练数据\n",
    "\n",
    "为了训练,对于每一对我们将需要输入的张量(输入句子中的词的索引)和\n",
    " 目标张量(目标语句中的词的索引). 在创建这些向量时,我们会将EOS标记添加到两个序列中.\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def indexesFromSentence(lang, sentence):\n",
    "    return [lang.word2index[word] for word in sentence.split(' ')]\n",
    "\n",
    "\n",
    "def variableFromSentence(lang, sentence):\n",
    "    indexes = indexesFromSentence(lang, sentence)\n",
    "    indexes.append(EOS_token)\n",
    "    result = Variable(torch.LongTensor(indexes).view(-1, 1))\n",
    "    if use_cuda:\n",
    "        return result.cuda()\n",
    "    else:\n",
    "        return result\n",
    "\n",
    "\n",
    "def variablesFromPair(pair):\n",
    "    input_variable = variableFromSentence(input_lang, pair[0])\n",
    "    target_variable = variableFromSentence(output_lang, pair[1])\n",
    "    return (input_variable, target_variable)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 训练模型\n",
    "\n",
    "\n",
    "为了训练我们通过编码器运行输入句子,并跟踪每个输出和最新的隐藏状态.\n",
    "然后解码器被赋予``<SOS>`` 指令作为其第一个输入,\n",
    "并将编码器的最后一个隐藏状态作为其第一个隐藏状态.\n",
    "\n",
    "\"Teacher forcing\" 是将实际目标输出用作每个下一个输入的概念,而不是将解码器的\n",
    "猜测用作下一个输入.使用教师强迫会使其更快地收敛,但是 `当训练好的网络被利用时,它可能表现出不稳定性.` [http://minds.jacobs-university.de/sites/default/files/uploads/papers/ESNTutorialRev.pdf](http://minds.jacobs-university.de/sites/default/files/uploads/papers/ESNTutorialRev.pdf).\n",
    "\n",
    "你可以观察教师强迫网络的输出,这些网络是用连贯的语法阅读的,但却远离了正确的翻译 - \n",
    "直观地来看它已经学会了代表输出语法,并且一旦老师告诉它前几个单词,就可以\"拾取\"它的意思,\n",
    " 但它没有适当地学会如何从翻译中创建句子.\n",
    "\n",
    "由于PyTorch的autograd给我们的自由,我们可以随意选择使用老师强制或不使用简单的if语句. \n",
    "打开``teacher_forcing_ratio``更多的使用它.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "teacher_forcing_ratio = 0.5\n",
    "\n",
    "\n",
    "def train(input_variable, target_variable, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion, max_length=MAX_LENGTH):\n",
    "    encoder_hidden = encoder.initHidden()\n",
    "\n",
    "    encoder_optimizer.zero_grad()\n",
    "    decoder_optimizer.zero_grad()\n",
    "\n",
    "    input_length = input_variable.size()[0]\n",
    "    target_length = target_variable.size()[0]\n",
    "\n",
    "    encoder_outputs = Variable(torch.zeros(max_length, encoder.hidden_size))\n",
    "    encoder_outputs = encoder_outputs.cuda() if use_cuda else encoder_outputs\n",
    "\n",
    "    loss = 0\n",
    "\n",
    "    for ei in range(input_length):\n",
    "        encoder_output, encoder_hidden = encoder(\n",
    "            input_variable[ei], encoder_hidden)\n",
    "        encoder_outputs[ei] = encoder_output[0][0]\n",
    "\n",
    "    decoder_input = Variable(torch.LongTensor([[SOS_token]]))\n",
    "    decoder_input = decoder_input.cuda() if use_cuda else decoder_input\n",
    "\n",
    "    decoder_hidden = encoder_hidden\n",
    "\n",
    "    use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False\n",
    "\n",
    "    if use_teacher_forcing:\n",
    "        # 教师强制: 将目标作为下一个输入\n",
    "        for di in range(target_length):\n",
    "            decoder_output, decoder_hidden, decoder_attention = decoder(\n",
    "                decoder_input, decoder_hidden, encoder_outputs)\n",
    "            loss += criterion(decoder_output, target_variable[di])\n",
    "            decoder_input = target_variable[di]  # Teacher forcing\n",
    "\n",
    "    else:\n",
    "        # 没有教师强迫: 使用自己的预测作为下一个输入\n",
    "        for di in range(target_length):\n",
    "            decoder_output, decoder_hidden, decoder_attention = decoder(\n",
    "                decoder_input, decoder_hidden, encoder_outputs)\n",
    "            topv, topi = decoder_output.data.topk(1)\n",
    "            ni = topi[0][0]\n",
    "\n",
    "            decoder_input = Variable(torch.LongTensor([[ni]]))\n",
    "            decoder_input = decoder_input.cuda() if use_cuda else decoder_input\n",
    "\n",
    "            loss += criterion(decoder_output, target_variable[di])\n",
    "            if ni == EOS_token:\n",
    "                break\n",
    "\n",
    "    loss.backward()\n",
    "\n",
    "    encoder_optimizer.step()\n",
    "    decoder_optimizer.step()\n",
    "\n",
    "    return loss.item() / target_length"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "根据当前时间和进度百分比,这是一个帮助功能,用于打印经过的时间和估计的剩余时间.\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import math\n",
    "\n",
    "\n",
    "def asMinutes(s):\n",
    "    m = math.floor(s / 60)\n",
    "    s -= m * 60\n",
    "    return '%dm %ds' % (m, s)\n",
    "\n",
    "\n",
    "def timeSince(since, percent):\n",
    "    now = time.time()\n",
    "    s = now - since\n",
    "    es = s / (percent)\n",
    "    rs = es - s\n",
    "    return '%s (- %s)' % (asMinutes(s), asMinutes(rs))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "整个训练过程如下所示:\n",
    "\n",
    "-  启动一个计时器\n",
    "-  初始化优化器和标准\n",
    "-  创建一组训练对\n",
    "-  为绘图建空损失数组\n",
    "\n",
    "然后我们多次调用``train``,偶尔打印进度(样本的百分比,到目前为止的时间,估计的时间)和平均损失. \n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def trainIters(encoder, decoder, n_iters, print_every=1000, plot_every=100, learning_rate=0.01):\n",
    "    start = time.time()\n",
    "    plot_losses = []\n",
    "    print_loss_total = 0  # Reset every print_every\n",
    "    plot_loss_total = 0  # Reset every plot_every\n",
    "\n",
    "    encoder_optimizer = optim.SGD(encoder.parameters(), lr=learning_rate)\n",
    "    decoder_optimizer = optim.SGD(decoder.parameters(), lr=learning_rate)\n",
    "    training_pairs = [variablesFromPair(random.choice(pairs))\n",
    "                      for i in range(n_iters)]\n",
    "    criterion = nn.NLLLoss()\n",
    "\n",
    "    for iter in range(1, n_iters + 1):\n",
    "        training_pair = training_pairs[iter - 1]\n",
    "        input_variable = training_pair[0]\n",
    "        target_variable = training_pair[1]\n",
    "\n",
    "        loss = train(input_variable, target_variable, encoder,\n",
    "                     decoder, encoder_optimizer, decoder_optimizer, criterion)\n",
    "        print_loss_total += loss\n",
    "        plot_loss_total += loss\n",
    "\n",
    "        if iter % print_every == 0:\n",
    "            print_loss_avg = print_loss_total / print_every\n",
    "            print_loss_total = 0\n",
    "            print('%s (%d %d%%) %.4f' % (timeSince(start, iter / n_iters),\n",
    "                                         iter, iter / n_iters * 100, print_loss_avg))\n",
    "\n",
    "        if iter % plot_every == 0:\n",
    "            plot_loss_avg = plot_loss_total / plot_every\n",
    "            plot_losses.append(plot_loss_avg)\n",
    "            plot_loss_total = 0\n",
    "\n",
    "    showPlot(plot_losses)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 绘制结果\n",
    "\n",
    "\n",
    "使用matplotlib完成绘图, 使用训练时保存的损失值``plot_losses``数组.\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.ticker as ticker\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def showPlot(points):\n",
    "    plt.figure()\n",
    "    fig, ax = plt.subplots()\n",
    "    # 这个定位器会定期发出提示信息\n",
    "    loc = ticker.MultipleLocator(base=0.2)\n",
    "    ax.yaxis.set_major_locator(loc)\n",
    "    plt.plot(points)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 评估\n",
    "\n",
    "\n",
    "评估与训练大部分相同,但没有目标,因此我们只是将解码器的每一步预测反馈给它自身.\n",
    "每当它预测到一个单词时,我们就会将它添加到输出字符串中,并且如果它预测到我们在那里停止的EOS指令.\n",
    "我们还存储解码器的注意力输出以供稍后显示.\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate(encoder, decoder, sentence, max_length=MAX_LENGTH):\n",
    "    input_variable = variableFromSentence(input_lang, sentence)\n",
    "    input_length = input_variable.size()[0]\n",
    "    encoder_hidden = encoder.initHidden()\n",
    "\n",
    "    encoder_outputs = Variable(torch.zeros(max_length, encoder.hidden_size))\n",
    "    encoder_outputs = encoder_outputs.cuda() if use_cuda else encoder_outputs\n",
    "\n",
    "    for ei in range(input_length):\n",
    "        encoder_output, encoder_hidden = encoder(input_variable[ei],\n",
    "                                                 encoder_hidden)\n",
    "        encoder_outputs[ei] = encoder_outputs[ei] + encoder_output[0][0]\n",
    "\n",
    "    decoder_input = Variable(torch.LongTensor([[SOS_token]]))  # SOS\n",
    "    decoder_input = decoder_input.cuda() if use_cuda else decoder_input\n",
    "\n",
    "    decoder_hidden = encoder_hidden\n",
    "\n",
    "    decoded_words = []\n",
    "    decoder_attentions = torch.zeros(max_length, max_length)\n",
    "\n",
    "    for di in range(max_length):\n",
    "        decoder_output, decoder_hidden, decoder_attention = decoder(\n",
    "            decoder_input, decoder_hidden, encoder_outputs)\n",
    "        decoder_attentions[di] = decoder_attention.data\n",
    "        topv, topi = decoder_output.data.topk(1)\n",
    "        ni = topi[0][0]\n",
    "        if ni == EOS_token:\n",
    "            decoded_words.append('<EOS>')\n",
    "            break\n",
    "        else:\n",
    "            decoded_words.append(output_lang.index2word[ni])\n",
    "\n",
    "        decoder_input = Variable(torch.LongTensor([[ni]]))\n",
    "        decoder_input = decoder_input.cuda() if use_cuda else decoder_input\n",
    "\n",
    "    return decoded_words, decoder_attentions[:di + 1]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们可以从训练集中评估随机的句子并打印出输入,目标和输出以作出一些主观质量判断:\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluateRandomly(encoder, decoder, n=10):\n",
    "    for i in range(n):\n",
    "        pair = random.choice(pairs)\n",
    "        print('>', pair[0])\n",
    "        print('=', pair[1])\n",
    "        output_words, attentions = evaluate(encoder, decoder, pair[0])\n",
    "        output_sentence = ' '.join(output_words)\n",
    "        print('<', output_sentence)\n",
    "        print('')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 训练和评估\n",
    "\n",
    "有了所有这些辅助功能(它看起来像是额外的工作,但它使运行多个实验更容易),\n",
    "我们就立马可以初始化网络并开始培训.\n",
    "\n",
    "请记住输入句子被严重过滤, 对于这个小数据集,我们可以使用包含256个隐藏节点\n",
    "和单个GRU层的相对较小的网络.在MacBook CPU上约40分钟后,我们会得到一些合理的结果.\n",
    "\n",
    "PS:\n",
    "   如果你运行这个notebook,你可以训练,打断内核,评估并在以后继续训练. \n",
    "   注释编码器和解码器初始化的行并再次运行 ``trainIters`` .\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "hidden_size = 256\n",
    "encoder1 = EncoderRNN(input_lang.n_words, hidden_size)\n",
    "attn_decoder1 = AttnDecoderRNN(hidden_size, output_lang.n_words, dropout_p=0.1)\n",
    "\n",
    "\n",
    "if use_cuda:\n",
    "    encoder1 = encoder1.cuda()\n",
    "    attn_decoder1 = attn_decoder1.cuda()\n",
    "\n",
    "trainIters(encoder1, attn_decoder1, 75000, print_every=5000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "evaluateRandomly(attn_decoder1, encoder1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "可视化注意力\n",
    "---------------------\n",
    "\n",
    "注意力机制的一个有用特性是其高度可解释的输出. \n",
    "由于它用于对输入序列的特定编码器输出进行加权,因此我们可以想象在每个时间步骤中查看网络最关注的位置.\n",
    "\n",
    "您可以简单地运行 ``plt.matshow(attentions)``,将注意力输出显示为矩阵,\n",
    "其中列是输入步骤,行是输出步骤.\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "output_words, attentions = evaluate(\n",
    "    encoder1, attn_decoder1, \"je suis trop froid .\")\n",
    "plt.matshow(attentions.numpy())\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "为了获得更好的观看体验,我们将额外添加轴和标签:\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def showAttention(input_sentence, output_words, attentions):\n",
    "    # 用颜色条设置图形\n",
    "    fig = plt.figure()\n",
    "    ax = fig.add_subplot(111)\n",
    "    cax = ax.matshow(attentions.numpy(), cmap='bone')\n",
    "    fig.colorbar(cax)\n",
    "\n",
    "    # 设置轴\n",
    "    ax.set_xticklabels([''] + input_sentence.split(' ') +\n",
    "                       ['<EOS>'], rotation=90)\n",
    "    ax.set_yticklabels([''] + output_words)\n",
    "\n",
    "    # 在每个打勾处显示标签\n",
    "    ax.xaxis.set_major_locator(ticker.MultipleLocator(1))\n",
    "    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))\n",
    "\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "def evaluateAndShowAttention(input_sentence):\n",
    "    output_words, attentions = evaluate(\n",
    "        encoder1, attn_decoder1, input_sentence)\n",
    "    print('input =', input_sentence)\n",
    "    print('output =', ' '.join(output_words))\n",
    "    showAttention(input_sentence, output_words, attentions)\n",
    "\n",
    "\n",
    "evaluateAndShowAttention(\"elle a cinq ans de moins que moi .\")\n",
    "\n",
    "evaluateAndShowAttention(\"elle est trop petit .\")\n",
    "\n",
    "evaluateAndShowAttention(\"je ne crains pas de mourir .\")\n",
    "\n",
    "evaluateAndShowAttention(\"c est un jeune directeur plein de talent .\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "练习\n",
    "=========\n",
    "\n",
    "-  尝试使用不同的数据集\n",
    "\n",
    "   -  另一种语言对\n",
    "   -  人 → 机器 (例如. IOT 命令)\n",
    "   -  聊天 → 响应\n",
    "   -  问题 → 回答\n",
    "\n",
    "-  用预先训练的词嵌入替换嵌入,例如word2vec或GloVe\n",
    "-  尝试更多图层,更多隐藏单位和更多句子. 比较训练时间和结果.\n",
    "-  如果您使用的翻译文件对中有两个相同的短语(``I am test \\t I am test``),\n",
    "   您可以使用它作为自动编码器.尝试这个:\n",
    "   -  训练自编码器\n",
    "   -  只保存编码器网络\n",
    "   -  从那里训练一个新的解码器进行翻译\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "本节完。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}